import numpy as np
import string, os, torch
from torch.utils.data import DataLoader
import torch.nn as nn
import sampling
from models_datasets import SingleDataset, generate_single_batch,PredictDataset2,PredictModel_base,PredictModel_sec2last,PredictModel_word2vec,generate_predict_batch2,PredictModel_attention,PredictDataset2_attn, generate_predict_batch2_attn, SingleDataset_attn, generate_single_batch_attn, ByLengthSampler,PredictModel_lastlayer
from models_datasets import  ContrastiveDataset, generate_contrast_batch, ReferenceDataset, generate_reference_batch, ContrastiveModel, PredictModel_rsm
from bow_data import get_data
# from torch.utils.tensorboard import SummaryWriter
import csv
import math
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.model_selection import KFold,cross_val_score,GridSearchCV
import collections
from sklearn.linear_model import LogisticRegression
import time
from train_lda import train_lda_unsupervised
from datetime import timedelta
import pickle 
# Batch size shared by the DataLoader / ByLengthSampler calls below; the attention
# training step also derives its gradient-accumulation interval from it (relative to 128).
# Earlier experiments used other values (see the inline note).
BATCH_SIZE = 512  # change for attn run 64 # 512->change 128 may 20 ->may 21 change from 128 to 8 for nce c = 1000
# Run on CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


'''
Helpers
'''
def choose_optimizer(model, lr, opt_type, use_scheduler=True, w_decay=0.0):
    """Create an optimizer over the trainable parameters of ``model``.

    Args:
        model: torch module; only parameters with ``requires_grad`` are optimized.
        lr: learning rate.
        opt_type: one of "sgd", "amsgrad", "adagrad", "adadelta", "rms";
            any other value selects plain Adam.
        use_scheduler: if True, also build a ReduceLROnPlateau scheduler
            (mode='min', patience=10, factor=0.5, min_lr=1e-7).
        w_decay: weight decay passed to the optimizer.

    Returns:
        ``(optimizer, scheduler)`` when use_scheduler is True, else ``optimizer``.
    """
    def trainable_params():
        # Fresh iterator per call: filter objects can only be consumed once.
        return filter(lambda p: p.requires_grad, model.parameters())

    if opt_type == "sgd":
        optimizer = torch.optim.SGD(trainable_params(), lr=lr, weight_decay=w_decay)
    elif opt_type == "amsgrad":
        optimizer = torch.optim.Adam(trainable_params(), lr=lr, weight_decay=w_decay, amsgrad=True)
    elif opt_type == "adagrad":
        optimizer = torch.optim.Adagrad(trainable_params(), lr=lr, weight_decay=w_decay)
    elif opt_type == "adadelta":
        # Previously built Adagrad here by mistake.
        optimizer = torch.optim.Adadelta(trainable_params(), lr=lr, weight_decay=w_decay)
    elif opt_type == "rms":
        optimizer = torch.optim.RMSprop(trainable_params(), lr=lr, momentum=0.009, weight_decay=w_decay)
    else:
        # Default for "adam" and any unrecognized opt_type.
        optimizer = torch.optim.Adam(trainable_params(), lr=lr, weight_decay=w_decay)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5, min_lr=1e-7)
        return (optimizer, scheduler)
    return optimizer

def read_csv(filename):
    """Read a two-column (text, label) CSV file.

    Rows whose second column is the literal 'Label' (a header) are skipped, as
    are completely blank lines.

    Returns:
        (X, Y): list of text strings and list of integer labels.
    """
    X = []
    Y = []
    # newline='' is what the csv module expects so quoted fields containing
    # newlines are parsed correctly.
    with open(filename, newline='') as csv_file:
        for row in csv.reader(csv_file, delimiter=','):
            if not row:
                # csv.reader yields [] for blank lines; row[1] would raise.
                continue
            if row[1] == 'Label':
                continue
            X.append(row[0])
            Y.append(int(row[1]))
    return (X, Y)

'''
For Contrastive

'''
def reference_embeddings(tokens, counts, ref_tokens, ref_counts, model, tol=0.05, take_exp=True):
    """Score every document against a fixed set of reference documents.

    Args:
        tokens, counts: token ids / counts of the documents to represent.
        ref_tokens, ref_counts: token ids / counts of the reference documents.
        model: contrastive model called as model(text_1, offset_1, text_2, offset_2).
        tol: when > 0, scores are capped at (1 - tol) / tol.
        take_exp: exponentiate the raw model outputs before capping.

    Returns:
        CPU float tensor viewed as (len(tokens), len(ref_tokens)).
    """
    n_docs = len(tokens)
    n_refs = len(ref_tokens)
    loader = DataLoader(
        dataset=ReferenceDataset(tokens, counts, ref_tokens, ref_counts),
        batch_size=BATCH_SIZE,
        num_workers=12,
        pin_memory=True,
        shuffle=False,
        collate_fn=generate_reference_batch,
    )

    with torch.no_grad():
        chunks = []
        model.to(device)
        model.eval()
        print("Making predictions")
        for doc_text, doc_offset, ref_text, ref_offset in loader:
            scores = model(
                doc_text.to(device),
                doc_offset.to(device),
                ref_text.to(device),
                ref_offset.to(device),
            ).squeeze(1).float()
            if take_exp:
                scores = torch.exp(scores)
            if tol > 0:
                cap = (1.0 - tol) / tol
                scores = torch.min(scores, cap * torch.ones_like(scores))
            chunks.append(scores.cpu())

        print("Tensorizing result")
        stacked = torch.cat(chunks).view(n_docs, n_refs)

    return stacked
def extract_embedding(data_path, model, reference_embed=True, dim=100):
    """Build train/test feature matrices from a contrastive model.

    Args:
        data_path: folder passed to get_data.
        model: trained contrastive model.
        reference_embed: if True, each document is represented by its scores
            against ``dim`` reference documents drawn at random (without
            replacement) from the unsup split, using reference_embeddings'
            default ``tol``/``take_exp``. Otherwise model_embeddings is used.
        dim: number of reference documents (the feature dimension when
            reference_embed is True).

    Returns:
        (X_train, X_test) as numpy arrays.
    """
    _, train, test, unsup, _ = get_data(data_path)

    # 1. training data
    train_tokens = train['tokens']
    train_counts = train['counts']

    # 2. testing set
    test_tokens = test['tokens']
    test_counts = test['counts']

    if reference_embed:
        unsup_tokens = unsup['tokens']
        unsup_counts = unsup['counts']

        ## Create representative documents
        inds = np.random.choice(len(unsup_tokens), dim, replace=False)
        ref_tokens = [unsup_tokens[i] for i in inds]
        ref_counts = [unsup_counts[i] for i in inds]

        print("Creating contrastive representation for testing documents.")
        X_test = reference_embeddings(tokens=test_tokens,
                                      counts=test_counts,
                                      ref_tokens=ref_tokens,
                                      ref_counts=ref_counts,
                                      model=model)

        print("Creating contrastive representation for training documents.")
        X_train = reference_embeddings(tokens=train_tokens,
                                       counts=train_counts,
                                       ref_tokens=ref_tokens,
                                       ref_counts=ref_counts,
                                       model=model)
        # reference_embeddings returns CPU torch tensors.
        return (X_train.numpy(), X_test.numpy())

    print("Creating contrastive representation for testing documents.")
    X_test = model_embeddings(test_tokens, test_counts, model)

    print("Creating contrastive representation for training documents.")
    X_train = model_embeddings(train_tokens, train_counts, model)

    # model_embeddings already returns numpy arrays; calling .numpy() on them
    # raised AttributeError.
    return (X_train, X_test)


'''
Generate embeddings from a neural network given tokens and counts.

Uses the model's get_embedding function (or, with custom_embedding=True, the
softmax of the model's output). To customize the embedding, change the
get_embedding function of the corresponding model class in models_datasets.py.

Output format: (ndoc, embedding dim). Attention models return a tuple of
(embeddings, labels) instead.
'''

#for predict model, this would be last layer softmaxed long vector. 
def model_embeddings(tokens, counts, model,Y=None,custom_embedding=False):
    """Compute neural-network embeddings for a set of documents.

    Args:
        tokens, counts: per-document token ids and counts, passed to the
            dataset class.
        model: trained network. PredictModel_attention instances use the
            attention dataset/sampler path (no offsets); any other model is
            called with (text, offset) batches.
        Y: labels; only used on the attention path, where they are batched
            alongside the documents.
        custom_embedding: non-attention path only. If True, use
            softmax(model(text, offset)) over dim 1 instead of
            model.get_embedding(text, offset).

    Returns:
        Attention models: tuple (embeddings, labels) of numpy arrays. Batches
        come from ByLengthSampler, which may reorder documents (TODO confirm),
        so the labels are returned in the same row order as the embeddings.
        Other models: numpy array with one row per document in input order
        (the loader uses shuffle=False).
    """
    if isinstance(model, PredictModel_attention):
        dataset = SingleDataset_attn(tokens, counts,Y)
        # batch_sampler is mutually exclusive with batch_size != 1.
        data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=2, pin_memory=True, collate_fn=generate_single_batch_attn,
                                 batch_sampler=ByLengthSampler(dataset, batchsize=BATCH_SIZE, key=None, shuffle=False))

        with torch.no_grad():
            all_preds = []
            all_labels = []
            model.to(device)
            model.eval()
            print("Getting embeddings")
            for text,y in data_loader:
                embedding = model.get_embedding(text.to(device))
                all_preds.append(embedding.cpu())
                all_labels.append(y.cpu())
            print("Tensorizing result")
            representation = torch.cat(all_preds)
            labels = torch.cat(all_labels)
        return representation.numpy(), labels.numpy()

    dataset = SingleDataset(tokens, counts)
    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, shuffle=False, collate_fn=generate_single_batch)

    # Built once instead of once per batch.
    softmax = nn.Softmax(dim=1) if custom_embedding else None

    with torch.no_grad():
        all_preds = []
        model.to(device)
        model.eval()
        print("Getting embeddings")
        print(f'custom embedding is {custom_embedding}')
        for text, offset in data_loader:
            text = text.to(device)
            offset = offset.to(device)
            if custom_embedding:
                embedding = softmax(model(text, offset))
            else:
                embedding = model.get_embedding(text, offset)
            all_preds.append(embedding.cpu())

        print("Tensorizing result")
        representation = torch.cat(all_preds)

    return representation.numpy()


'''
Generate BOW embeddings given tokens and counts 

output format, (ndoc, dim of vocab)
'''
def BOW_embeddings(datapath,tokens,counts):
    """Return dense bag-of-words count vectors, one numpy array per document.

    tokens[k][0] / counts[k][0] hold the word ids and their counts for
    document k. The vector length is the vocabulary size reported by
    get_data(datapath).
    """
    vocabulary, _train, _test, _unsup, _valid = get_data(datapath)
    n_words = len(vocabulary)

    def to_vector(word_ids, word_counts):
        # Later occurrences of the same id overwrite earlier ones.
        vec = np.zeros(n_words)
        for word_id, freq in zip(word_ids, word_counts):
            vec[word_id] = freq
        return vec

    return [to_vector(doc_ids[0], doc_freqs[0]) for doc_ids, doc_freqs in zip(tokens, counts)]

'''
Generate Word2Vec embeddings given tokens and counts.

Load the self-trained word2vec matrix, look up the embedding of each word, and
take the count-weighted average over the document.

Output format: (ndoc, word2vec embedding dim; 300 for the self-trained embeddings)
'''
def Word2Vec_embeddings(datapath,tokens,counts,word2vec_emb_size=300,word2vec_matrix=None,random = False):
    """Count-weighted average of word vectors for each document.

    Args:
        datapath: folder containing
            ``word2vec_matrix_trained_embsize<word2vec_emb_size>.npy``; it is
            only read when ``word2vec_matrix`` is None.
        tokens, counts: tokens[k][0] / counts[k][0] are the word ids and
            counts of document k.
        word2vec_emb_size: embedding dimension.
        word2vec_matrix: optional preloaded (vocab, word2vec_emb_size) matrix.
        random: if True, each word occurrence contributes a fresh
            np.random.rand vector instead of its word2vec row (baseline).

    Returns:
        List with one length-``word2vec_emb_size`` numpy array per document.
        Documents with zero total count get an all-zero vector (previously
        0/0 produced NaNs that break downstream classifiers).
    """
    # (The former get_data() call here only computed an unused vocab size.)
    if word2vec_matrix is None:
        word2vec_matrix_path = datapath+'/word2vec_matrix_trained_embsize'+str(word2vec_emb_size)+'.npy'
        with open(word2vec_matrix_path,'rb') as f:
            # directly load pre-trained word2vec matrix from file
            word2vec_matrix=np.load(f)

    embeddings=[]
    for token,count in zip(tokens,counts):
        embedding_sum=np.zeros(word2vec_emb_size)
        for t,c in zip(token[0], count[0]):
            if random:
                embedding_sum+=np.random.rand(word2vec_emb_size)*c
            else:
                embedding_sum+=word2vec_matrix[t]*c

        total = sum(count[0])
        embeddings.append(embedding_sum/total if total else embedding_sum)
    return embeddings


'''
Generate LDA embeddings given tokens and counts.
Load the LDA model pretrained on the unsupervised set from
'[datapath]/dim_[lda_emb_size]_lda_model.sav', training it first if the file is missing.
Output format: (ndoc, lda dim)
'''
def LDA_embedding(datapath,tokens,counts,lda_emb_size = 50):
    """Embed documents with a pickled LDA model.

    Loads ``<datapath>/dim_<lda_emb_size>_lda_model.sav``, first training it
    with train_lda_unsupervised when the file does not exist. It then builds a
    dense (ndoc, vocab_size) count matrix from tokens[k][0] / counts[k][0]
    and returns ``lda.transform`` of that matrix.
    """
    filename = datapath+'/dim_'+str(lda_emb_size)+'_lda_model.sav'
    if not os.path.exists(filename):
        print(f"LDA embedding for dim {lda_emb_size} does not exist, running train...")
        train_lda_unsupervised(datapath=datapath,lda_dim=lda_emb_size)

    # NOTE: pickle can execute arbitrary code on load; only use trusted model files.
    # The context manager closes the handle (the original leaked it).
    with open(filename, 'rb') as f:
        lda = pickle.load(f)

    ndoc = len(tokens)
    vocab, _, _, _, _ = get_data(datapath)
    vocab_size=len(vocab)
    X=np.zeros((ndoc,vocab_size))
    for i,(token,count) in enumerate(zip(tokens,counts)):
        for t,c in zip(token[0], count[0]):
            X[i][t]=c

    embeddings = lda.transform(X)
    return embeddings



'''
UNSUPERVISED Phase

Run unsupervised learning to train predict model
'''
def train_model(data_path,results_folder,t, c_dim, h_dim,representation_dim, nepochs, lr, embed_dim, opt_type,w_decay,model_id=None, landmarks=0, dropout_p=0.5, n_layers=3, nfolds=1, resample=0, save_freq=25, use_scheduler=True, pretrained_vectors=False, temp_model_folder=None, prev_model_file=None, presampled_docs_file=None, synthetic_args=None):
    """Run the unsupervised phase: train a representation model on sampled documents.

    Training examples are built from the ``unsup`` split returned by get_data.
    Prediction targets come from ``sampling.real_predict_document``; with
    ``model_id == 'contrastive'``, document pairs come from
    ``sampling.even_contrastive_documents`` instead. The ``valid`` split is
    sampled the same way for validation. Whenever the validation loss improves,
    the whole model is saved to a temporary checkpoint. That best checkpoint is
    reloaded, deleted from disk, and returned.

    Args:
        data_path: dataset folder passed to get_data. It also holds the
            word2vec matrices and ``skipEmbeddings.npy``.
        results_folder: created if missing. The validation-loss history is
            written under ``<results_folder>/epoch_<nepochs>/``.
        t: passed to sampling.real_predict_document. The prediction losses
            average over target columns 0..t-1 of each example.
        c_dim, h_dim, representation_dim, embed_dim, dropout_p, n_layers:
            architecture sizes forwarded to the model constructor.
        nepochs: number of training epochs.
        lr, opt_type, w_decay, use_scheduler: forwarded to choose_optimizer.
        model_id: selects the architecture (see choose_model below).
        resample: if > 0, re-sample train/valid documents every ``resample``
            epochs (never after the final epoch).
        pretrained_vectors: if True, vocab_size and embed_dim are taken from
            the shape of ``skipEmbeddings.npy``. The loaded vectors themselves
            are not passed to the model here.
        temp_model_folder: directory for the temporary checkpoint. If None,
            the checkpoint goes in the current working directory.
        prev_model_file: if given, this saved model replaces the freshly
            built one before training.
        landmarks, nfolds, save_freq, presampled_docs_file, synthetic_args:
            accepted for interface compatibility; unused here.

    Returns:
        The model with the lowest validation loss seen. This is the starting
        model if validation never improved.

    Raises:
        ValueError: if model_id is not recognized and prev_model_file is None.
    """
    if(not os.path.isdir(results_folder)):
        os.mkdir(results_folder)

    vocab, train, test, unsup, valid = get_data(data_path)
    vocab_size = len(vocab)

    ## Get data
    unsup_tokens = unsup['tokens']
    unsup_counts = unsup['counts']

    valid_tokens = valid['tokens']
    valid_counts = valid['counts']

    ## Load up pretrained vectors
    vectors = None
    if(pretrained_vectors):
        pretrained_vectors_file = os.path.join(data_path, 'skipEmbeddings.npy')
        vectors = torch.from_numpy(np.load(pretrained_vectors_file)).float()
        # The vector file's shape overrides the vocabulary size and embed_dim.
        vocab_size, embed_dim = vectors.shape

    ## Build model
    print("building model...")
    # Random 10-letter prefix keeps checkpoint names distinct across runs.
    temp_file = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)) + list("_model_temp.pt"))
    if(temp_model_folder):
        temp_file = os.path.join(temp_model_folder, temp_file)


    ## Define sampling function for experiments only
    # No synthetic documents
    def sample_documents_experiments(model_id):
        """Sample fresh train/valid examples and wrap them in DataLoaders for model_id."""
        print("Sampling predict documents...")
        if model_id=="attention":
            raw_train_dataset = sampling.real_predict_document(unsup_tokens,unsup_counts,t=t)
            raw_valid_dataset = sampling.real_predict_document(valid_tokens,valid_counts,t=t)

            train_dataset = PredictDataset2_attn(raw_train_dataset[0], raw_train_dataset[1])
            valid_dataset = PredictDataset2_attn(raw_valid_dataset[0], raw_valid_dataset[1])

            # Batching is delegated to ByLengthSampler (batch_sampler is
            # mutually exclusive with batch_size != 1).
            train_loader = DataLoader(dataset=train_dataset, batch_size=1, collate_fn=generate_predict_batch2_attn,
                                      batch_sampler=ByLengthSampler(train_dataset, batchsize=BATCH_SIZE,shuffle=True,ignore_len=2),
                                      num_workers=2, pin_memory=True)
            valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, collate_fn=generate_predict_batch2_attn,
                                      batch_sampler=ByLengthSampler(valid_dataset, batchsize=BATCH_SIZE,shuffle=False,ignore_len=2),
                                      num_workers=2, pin_memory=True)
            return(train_loader, valid_loader)
        elif model_id=="contrastive":
            raw_train_dataset = sampling.even_contrastive_documents(unsup_tokens, unsup_counts, nfolds=1)
            raw_valid_dataset = sampling.even_contrastive_documents(valid_tokens, valid_counts, nfolds=1)
            train_dataset = ContrastiveDataset(raw_train_dataset[0], raw_train_dataset[1], raw_train_dataset[2], raw_train_dataset[3], raw_train_dataset[4])
            valid_dataset = ContrastiveDataset(raw_valid_dataset[0], raw_valid_dataset[1], raw_valid_dataset[2], raw_valid_dataset[3], raw_valid_dataset[4])

            valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_contrast_batch)
            train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_contrast_batch, shuffle=True)
            return(train_loader, valid_loader)

        else:

            raw_train_dataset = sampling.real_predict_document(unsup_tokens,unsup_counts,t=t)
            raw_valid_dataset = sampling.real_predict_document(valid_tokens,valid_counts,t=t)

            train_dataset = PredictDataset2(raw_train_dataset[0], raw_train_dataset[1])
            valid_dataset = PredictDataset2(raw_valid_dataset[0], raw_valid_dataset[1])

            valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_predict_batch2)
            train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_predict_batch2, shuffle=True)
            return(train_loader, valid_loader)

    train_loader, valid_loader = sample_documents_experiments(model_id)

    ## Get model
    def choose_model(model_id,datapath):
        """Instantiate the architecture named by model_id; returns None if unrecognized."""
        if model_id=='word2vec':
            with open(datapath+'/word2vec_matrix.npy','rb') as f:
                # directly load pre-generated word2vec matrix from file
                word2vec_matrix=np.load(f)
            return PredictModel_word2vec(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p, n_layers=n_layers, word2vec_matrix=word2vec_matrix)
        if model_id=='contrastive':
            return ContrastiveModel(vocab_size=vocab_size, embed_dim=embed_dim, c_dim=c_dim, h_dim=h_dim, dropout_p=dropout_p, n_layers=n_layers)

        if model_id=='word2vec_train':
            with open(datapath+'/word2vec_matrix_trained.npy','rb') as f:
                word2vec_matrix=np.load(f)
            return PredictModel_word2vec(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p, n_layers=n_layers, word2vec_matrix=word2vec_matrix)

        if model_id=='lastlayer':
            return PredictModel_lastlayer(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim,dropout_p=dropout_p, n_layers=n_layers)
        if model_id=='base':
            return PredictModel_base(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim,dropout_p=dropout_p, n_layers=n_layers)

        if model_id=='base_rsm':
            rs_matrix=np.random.normal(0, math.sqrt(1/h_dim), size=(h_dim,h_dim))
            return PredictModel_rsm(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim,dropout_p=dropout_p, n_layers=n_layers,rs_matrix=rs_matrix)
        if model_id=='base_rsm2':
            rs_matrix=np.random.normal(0, math.sqrt(1/h_dim), size=(h_dim,2*h_dim)) # a larger random matrix
            return PredictModel_rsm(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim,dropout_p=dropout_p, n_layers=n_layers,rs_matrix=rs_matrix)

        if model_id=='sec2last':
            return PredictModel_sec2last(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim,representation_dim=representation_dim, dropout_p=dropout_p, n_layers=n_layers)

        if model_id=='attention':
            return PredictModel_attention(embed_dim=embed_dim,vocab_size=vocab_size, h_dim=h_dim, dropout_p=dropout_p, n_layers=n_layers, init_weight=None)
        print("No valid model_id:",model_id)
        return None


    model = choose_model(model_id,data_path)

    ## Load model if it exists
    if(prev_model_file is not None):
        print(prev_model_file)
        model = torch.load(prev_model_file)

    if model is None:
        # Fail here with a clear message instead of an AttributeError on model.to().
        raise ValueError(f"Unknown model_id {model_id!r} and no prev_model_file given")

    model = model.to(device)
    print(model)
    # Checkpoint the starting model so a "best" model always exists on disk.
    torch.save(model, temp_file)

    ## Loss function
    if model_id=='contrastive':
        loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
    elif model_id=='attention':
        loss_fn = nn.CrossEntropyLoss(reduction='sum')
    else:
        # Per-example losses, accumulated over the t targets and summed below.
        loss_fn = nn.CrossEntropyLoss(reduction='none')

    ## Scheduler/optimizer
    # w_decay is forwarded on both paths, so weight decay no longer depends on use_scheduler.
    if(use_scheduler):
        optimizer, scheduler = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler,w_decay=w_decay)
    else:
        optimizer = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler,w_decay=w_decay)

    ## Validation loss + accuracy
    def validation_loss(model_id):
        """Return (mean loss per example, accuracy) on valid_loader; accuracy is 0 for non-contrastive models."""
        if model_id=='contrastive':
            model.eval()
            total_loss = 0.0
            total_examples = 0.0
            total_correct = 0.0
            with torch.no_grad():
                for text_1, offset_1, text_2, offset_2, y in valid_loader:
                    n_examples = len(offset_1)
                    outputs = model(text_1.to(device), offset_1.to(device), text_2.to(device), offset_2.to(device)).squeeze(1).cpu()
                    # Match targets to the (batch,) outputs. y.squeeze() turns a
                    # batch of one into a 0-d tensor that the loss rejects, and an
                    # unsqueezed y broadcasts to batch x batch in the comparison below.
                    target = y.reshape(outputs.shape)

                    ## Compute loss
                    curr_loss = loss_fn(outputs, target)
                    total_loss += curr_loss.item()

                    ## Compute accuracy
                    # outputs are logits (BCEWithLogitsLoss), so p > 0.5 <=> logit > 0.
                    predictions = torch.where(outputs > 0, torch.ones_like(outputs), torch.zeros_like(outputs))
                    total_correct += (predictions == target).sum().item()

                    total_examples += n_examples
            return(total_loss/total_examples, total_correct/total_examples)
        elif model_id == 'attention':
            model.eval()
            total_loss = 0.0
            total_examples = 0.0
            total_correct = 0.0
            with torch.no_grad():
                for text_1, y in valid_loader:
                    n_examples = len(text_1)

                    predictions = model(text_1.to(device)).squeeze(1)

                    # Accumulate on the same device as the predictions.
                    loss=torch.zeros(1, dtype=torch.float).to(device)

                    for i in range(t):
                        loss += loss_fn(predictions,y[:,i].to(device))
                    loss=loss/t

                    total_loss += loss.item()

                    ## Accuracy is not computed
                    total_correct =0

                    total_examples += n_examples
            return(total_loss/total_examples, total_correct/total_examples)

        else:
            model.eval()
            total_loss = 0.0
            total_examples = 0.0
            total_correct = 0.0
            with torch.no_grad():
                for text_1, offset_1, y in valid_loader:
                    n_examples = len(offset_1)

                    predictions = model(text_1.to(device), offset_1.to(device)).squeeze(1).cpu()

                    loss=torch.zeros(len(y), dtype=torch.float)

                    for i in range(t):
                        loss += loss_fn(predictions,torch.tensor([x[i] for x in y]))
                    loss=loss/t
                    loss=loss.sum()

                    total_loss += loss.item()

                    ## Accuracy is not computed
                    total_correct =0

                    total_examples += n_examples
            return(total_loss/total_examples, total_correct/total_examples)

    ## Training step
    def train_step(model_id):
        """Run one epoch over train_loader and return the mean training loss per example."""
        if model_id=='contrastive':
            CLIP_NORM = 25.0
            model.train()
            train_loss = 0.0
            total_examples = 0.0
            for text_1, offset_1, text_2, offset_2, y in train_loader:
                optimizer.zero_grad()
                model.zero_grad()

                n_examples = len(offset_1)
                predictions = model(text_1.to(device), offset_1.to(device), text_2.to(device), offset_2.to(device)).squeeze(1)
                # reshape (not squeeze) so a final batch of one keeps shape (1,).
                loss = loss_fn(predictions, y.reshape(predictions.shape).to(device))

                train_loss += loss.item()
                total_examples += n_examples

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                optimizer.step()

            return(train_loss/total_examples)

        elif model_id == 'attention':
            CLIP_NORM = 25.0
            # Loader batches accumulated per optimizer step (presumably targeting
            # ~128 documents per step). With BATCH_SIZE > 128 the integer division
            # is 0, which made the modulo below raise ZeroDivisionError.
            accum_steps = max(1, 128 // BATCH_SIZE)
            model.train()
            train_loss = 0.0
            total_examples = 0.0
            model.zero_grad(set_to_none=True)
            pending_grads = False

            for idx,(text_1, y) in enumerate(train_loader):
                n_examples = len(text_1)
                for i in range(t):
                    predictions = model(text_1.to(device)).squeeze(1)
                    loss = loss_fn(predictions,y[:,i].to(device))/t

                    train_loss += loss.item()
                    loss.backward()
                total_examples += n_examples
                pending_grads = True

                if (idx + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                    optimizer.step()
                    model.zero_grad(set_to_none=True)
                    pending_grads = False

            # Apply a trailing partial accumulation group instead of discarding it.
            if pending_grads:
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                optimizer.step()
                model.zero_grad(set_to_none=True)

            return(train_loss/total_examples)

        else:
            CLIP_NORM = 25.0
            model.train()
            train_loss = 0.0
            total_examples = 0.0
            for text_1, offset_1, y in train_loader:
                optimizer.zero_grad()
                model.zero_grad()

                n_examples = len(offset_1)
                predictions = model(text_1.to(device), offset_1.to(device)).squeeze(1)

                loss=torch.zeros(len(y), dtype=torch.float).to(device)

                for i in range(t):
                    loss += loss_fn(predictions,torch.tensor([x[i] for x in y]).to(device))
                loss=loss/t
                loss=loss.sum()

                train_loss += loss.item()
                total_examples += n_examples

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                optimizer.step()

            return(train_loss/total_examples)

    valid_loss_list = []
    valid_loss, valid_acc = validation_loss(model_id)
    best_valid_loss = valid_loss

    ## Define save function
    def save_stats(epoch):
        """Write the pending validation-loss history to results_folder/epoch_<epoch>/ and clear it."""
        folder = os.path.join(results_folder, "_".join(["epoch", str(epoch)]))
        if(not os.path.isdir(folder)):
            os.mkdir(folder)

        rand_tag = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)))
        loss_name = os.path.join(folder, "_".join([rand_tag, "valid_loss.npy"]))
        np.save(loss_name, np.array(valid_loss_list))

        del valid_loss_list[:]

    print("Validation loss: ", valid_loss)
    print("Validation accuracy: ", valid_acc)
    valid_loss_list.append(valid_loss)
    for epoch in range(1, nepochs+1):
        train_loss = train_step(model_id)
        valid_loss, valid_acc = validation_loss(model_id)

        valid_loss_list.append(valid_loss)

        if(use_scheduler):
            scheduler.step(valid_loss)

        if(valid_loss < best_valid_loss and temp_file):
            best_valid_loss = valid_loss
            ## Save out the current model
            torch.save(model, temp_file)

        if((resample > 0) and (epoch < nepochs) and ((epoch % resample) == 0)):
            ## Resample documents; train_step/validation_loss read these names
            ## from this enclosing scope, so they pick up the new loaders.
            train_loader, valid_loader = sample_documents_experiments(model_id)
        print(valid_loss)
        print("epoch:")
        print(epoch)

    save_stats(nepochs) ## One last time
    print("last valid accu:")
    print(valid_acc)
    print("best valid loss:")
    print(best_valid_loss)
    # NOTE(review): torch.load of a full pickled model; recent PyTorch defaults
    # to weights_only=True, which may require weights_only=False here — confirm
    # against the installed version.
    model = torch.load(temp_file)
    os.remove(temp_file)
    print("===============")
    print("END Unsupervised Learning")
    print("===============")
    return (model)
    


'''
SUPERVISED Phase

Get model embeddings and run supervised learning and CV on supervised dataset and evaluate performance
on test set

word2vec_emb_size is only useful when embedding_type = word2vec
'''
def train_classifier(dataset = "agnews",model=None,n_samples=4000,embedding_type='model',\
    contrastive=False,data_path=None,word2vec_emb_size=300,lda_emb_size=50,word2vec_matrix = None,\
    random = False,custom_embedding = False):
    """Fit a logistic-regression probe on document embeddings and report test accuracy.

    Args:
        dataset: selects the folder ``data/data_<dataset>``. Features come
            from get_data and labels from its train.csv / test.csv.
        model: trained network; required when embedding_type == 'model'.
        n_samples: number of training documents sampled (without
            replacement) for fitting; their original order is kept.
        embedding_type: 'model', 'BOW', 'word2vec' or 'lda'.
        contrastive: with embedding_type 'model', use extract_embedding
            (reference-document features) instead of model_embeddings.
        data_path: ignored. It is always recomputed from ``dataset`` and kept
            only for backward compatibility.
        word2vec_emb_size, word2vec_matrix, random: forwarded to Word2Vec_embeddings.
        lda_emb_size: forwarded to LDA_embedding.
        custom_embedding: forwarded to model_embeddings.

    Returns:
        (best_C, best_solver, test accuracy, grid-search wall time as a string).

    Raises:
        ValueError: if no embeddings were produced (missing model or unknown
            embedding_type).
    """
    data_path = 'data/data_'+dataset
    vocab, train, test, unsup, valid = get_data(data_path)
    print(f"train_classifer getting data from {dataset}")
    ## Get data
    train_tokens = train['tokens']
    train_counts = train['counts']

    test_tokens = test['tokens']
    test_counts = test['counts']

    ## Get Y Labels
    _,Y_train=read_csv(f'data/data_{dataset}/train.csv')
    _,Y_test=read_csv(f'data/data_{dataset}/test.csv')

    ## Get Embeddings
    X_train, X_test = None, None
    if embedding_type=='model':
        if model is None:
            print("ERROR: No model as input ")
        elif contrastive:
            X_train,X_test=extract_embedding(data_path,model)
        elif isinstance(model, PredictModel_attention):
            # The attention path returns labels in the same row order as its embeddings.
            X_train,Y_train = model_embeddings(train_tokens, train_counts, model,Y_train)
            X_test,Y_test = model_embeddings(test_tokens, test_counts, model,Y_test)
        else:
            X_train = model_embeddings(train_tokens, train_counts, model,custom_embedding = custom_embedding)
            X_test = model_embeddings(test_tokens, test_counts, model,custom_embedding=custom_embedding)
            print("bp1: print X_train shape:",X_train.shape)

    elif embedding_type=='BOW':
        X_train = BOW_embeddings(data_path,train_tokens, train_counts)
        X_test = BOW_embeddings(data_path,test_tokens, test_counts)

    elif embedding_type=='word2vec':
        X_train = Word2Vec_embeddings(data_path,train_tokens, train_counts,word2vec_emb_size,word2vec_matrix,random = random)
        X_test = Word2Vec_embeddings(data_path,test_tokens, test_counts,word2vec_emb_size,word2vec_matrix,random = random)

    elif embedding_type=='lda':
        X_train = LDA_embedding(data_path,train_tokens, train_counts,lda_emb_size)
        X_test = LDA_embedding(data_path,test_tokens, test_counts,lda_emb_size)

    else:
        print("No known Embedding type specified...")

    if X_train is None:
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError(f"No embeddings produced (embedding_type={embedding_type!r}, model given: {model is not None})")

    ## Select n_samples training documents, keeping their original order
    inds=np.random.choice(len(X_train),size=n_samples,replace=False)
    # Set membership is O(1); `i in inds` on a numpy array scans all of inds each time.
    keep = set(inds.tolist())
    X_train=[X_train[i] for i in range(len(X_train)) if i in keep]
    Y_train=[Y_train[i] for i in range(len(Y_train)) if i in keep]
    print("the representation dimension is: !!!!!!!!!!")
    print(len(X_train[0]))

    #CV for linear classifier w/ Logistic Regression
    cv_startime=time.time()

    logistics=LogisticRegression(multi_class='multinomial')
    C=np.logspace(-3,3,10)
    solver=['lbfgs', 'sag', 'saga', 'newton-cg' ]
    hyperparameters=dict(C=C,solver=solver)

    print("starting grid search fitting...")
    clf=GridSearchCV(logistics,hyperparameters,cv=3,verbose=1)
    best_model=clf.fit(X_train,Y_train)
    best_C=best_model.best_estimator_.get_params()['C']
    best_solver=best_model.best_estimator_.get_params()['solver']
    print('Best c:',best_C)
    print('Best sovler',best_solver)

    cv_endtime=time.time()
    cv_duration=str(timedelta(seconds=cv_endtime - cv_startime))

    print('Evaluating test using best model')

    # Refit on the sampled training set with the selected hyperparameters.
    final_clf=LogisticRegression(C=best_C,multi_class='multinomial',solver=best_solver)
    final_clf.fit(X_train,Y_train)
    accu=(np.array(final_clf.predict(X_test))==np.array(Y_test)).sum()/len(Y_test)
    print(accu)

    return (best_C,best_solver,accu,cv_duration)


